#! /usr/bin/env python3
###########################################################################
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT
###########################################################################
import argparse
import fcntl
import os.path
import re
import selectors
import shlex
import signal
import subprocess
import sys
import time

# Work out the directory holding this script (as given by sys.argv[0])
# and its parent directory, then add the parent's 'py-lib' directory to
# the module search path.  This must run before the project imports
# that follow.
script_dir = os.path.realpath(os.path.dirname(sys.argv[0]))
top_dir = os.path.dirname(script_dir)
sys.path.append(os.path.join(top_dir, 'py-lib'))
from hex_file import HexReader, HexWriter
from qemu_socket import QemuSocket, LISTEN, CONNECT
from utils import dump_packet, parse_time

class App:
  def __init__(self, argv):
    """Set up the test state and parse the command line.

    argv is the list of command-line arguments without the script
    name.  The interpreter exits with status 1, after a message on
    stderr, when a program is named together with --standalone, when
    no program is named without --standalone, or when the named
    program does not exist.
    """
    # The handle for the child process if any.
    self.__proc = None

    # The socket object.
    self.__socket = None

    # The selector object.
    self.__sel = selectors.DefaultSelector()

    # The input packet reader object.
    self.__in_reader = None

    # An iterator over input packets.
    self.__in_iter = None

    # The expected packet reader object.
    self.__exp_reader = None

    # An iterator over expected packets.
    self.__exp_iter = None

    # The next packet to expect, if any.  The packet needs to be taken
    # from the iterator to work out whether the last packet has
    # arrived.  It needs to be stored so that it can be matched when
    # the next packet does arrive.
    self.__exp_packet = None

    # The index of the next packet.
    self.__exp_index = 0

    # The output file, if any.
    self.__out_file = None

    # The current (incomplete) log line from the program.
    self.__log_line = b""

    # A flag indicating whether the test has failed.
    self.__test_failed = False

    # True if the process exited.
    self.__proc_exited = False

    # Set if all packets have been sent.
    self.__send_done = False

    # Set if an EOF has been received from the socket.  This means we
    # should not call the receive method on the socket.
    self.__recv_eop = False

    # Set if excess packets were received.
    self.__recv_excess = False

    # Set if a packet mismatch occurred.
    self.__recv_mismatch = False

    # Set if the test timed out.
    self.__timed_out = False

    # Set when the test is complete.
    self.__complete = False

    # True if the selector is enabled.
    self.__sel_enabled = False

    p = argparse.ArgumentParser()

    p.add_argument("program", type=str, default=None, nargs='?',
                   help="The program to run.")

    p.add_argument("args", metavar="arg", default=[], nargs="*",
                   help="Extra program arguments.")

    p.add_argument("--input", "-i",
                   action="append", default=[],
                   help="Add an input file to process.")

    p.add_argument("--output", "-o",
                   action="store", default=None,
                   help="Set an output file to generate.")

    p.add_argument("--expect", "-e",
                   action="append", default=[],
                   help="Add an expected output file to process.")

    p.add_argument("--host", "-H",
                   action="store", default="127.0.0.1", type=str,
                   help="Set the hostname for the socket.")

    p.add_argument("--port", "-p",
                   action="store", default=0, type=int,
                   help="Set the port number for the socket.")

    # --connect and --listen share one destination; the default set
    # here (CONNECT) applies when neither option is given.
    p.add_argument("--connect", "-c", dest="mode", default=CONNECT,
                   action="store_const", const=CONNECT,
                   help="Operate the socket in connect mode.")

    p.add_argument("--listen", "-l", dest="mode",
                   action="store_const", const=LISTEN,
                   help="Operate the socket in listen mode.")

    p.add_argument("--standalone", "-s", action="store_true",
                   help="Enable stand-alone mode.")

    p.add_argument("--timeout", "-t", dest="timeout", default="100ms",
                   action="store", type=str,
                   help="Set the timeout (ns/us/ms/s).")

    p.add_argument("--wait-time", "-w", dest="wait_time", default=None,
                   action="store", type=str,
                   help="Set the wait time (ns/us/ms/s).")

    p.add_argument("--mtu", "-m", default=10240,
                   action="store", type=int,
                   help="Set the MTU.")

    p.add_argument("--quiet", "-q", action="count", default=0,
                   help="Produce less output.")

    p.add_argument("--verbose", "-v", action="count", default=0,
                   help="Produce verbose output.")

    self.__args = p.parse_args(argv)

    self.__verbosity = self.__args.verbose - self.__args.quiet
    self.__timeout = parse_time(self.__args.timeout)

    # Only hand an actual string to parse_time; when no wait time was
    # given (or parsing yields None) fall back to a tenth of the
    # timeout.
    wait_time = None
    if self.__args.wait_time is not None:
      wait_time = parse_time(self.__args.wait_time)
    if wait_time is None:
      wait_time = self.__timeout / 10
    self.__wait_time = wait_time

    if self.__args.standalone:
      if self.__args.program is not None:
        sys.stderr.write("%s: error: Program name was specified with"
                         " --standalone.  Use --help for help.\n" %
                         (sys.argv[0],))
        sys.exit(1)
    else:
      if self.__args.program is None:
        sys.stderr.write("%s: error: no program name was specified."
                         "  Use --help for help.\n" %
                         (sys.argv[0],))
        sys.exit(1)

      if not os.path.exists(self.__args.program):
        sys.stderr.write("%s: error: Program %r does not exist.\n" %
                         (sys.argv[0], self.__args.program))
        sys.exit(1)

  def get_time(self):
    """Return the current reading of the monotonic clock in seconds."""
    now = time.clock_gettime(time.CLOCK_MONOTONIC)
    return now

  def create_iterators(self):
    """Open the input and expected packet sources and the output file.

    The first expected packet is fetched straight away so that the
    rest of the code can tell whether anything is still expected.
    """
    opts = self.__args
    # The readers are run three levels quieter than the application.
    reader_verbosity = self.__verbosity - 3

    in_reader = HexReader(opts.input, verbosity=reader_verbosity)
    self.__in_reader = in_reader
    self.__in_iter = in_reader.iter_packets()

    exp_reader = HexReader(opts.expect, verbosity=reader_verbosity)
    self.__exp_reader = exp_reader
    self.__exp_iter = exp_reader.iter_packets()

    self.__exp_packet = next(self.__exp_iter, None)
    self.__exp_index = 0

    if opts.output is None:
      return
    self.__out_file = HexWriter(opts.output)

  def create_socket(self):
    """Create the socket object from the command-line settings.

    In listen mode, at verbosity 1 or higher, the host name and port
    reported by the socket are announced on stderr.
    """
    opts = self.__args
    # The socket is run one level quieter than the application.
    sock = QemuSocket(hostname=opts.host,
                      port=opts.port,
                      mode=opts.mode,
                      mtu=opts.mtu,
                      timeout=self.__timeout,
                      verbosity=self.__verbosity - 1)
    self.__socket = sock

    if self.__verbosity >= 1 and opts.mode == LISTEN:
      where = (sock.get_hostname(), sock.get_port())
      sys.stderr.write("Listening on %s port %d.\n" % where)

  def start_test(self):
    """Record the start time and the deadline for the test."""
    started = self.get_time()
    self.__start_time = started
    self.__end_time = started + self.__timeout

    if self.__verbosity < 1:
      return
    sys.stderr.write("Starting at time %s seconds.\n" % started)

  def invoke_program(self):
    """Start the program under test, unless in stand-alone mode.

    The program is given a --socket argument naming the opposite
    socket role to this script's, together with the socket's host
    name and port.  Its stdout and stderr are merged into one pipe,
    which is made non-blocking and registered with the selector.
    """
    opts = self.__args
    if opts.standalone:
      return

    # The program takes the role opposite to ours.
    prog_mode = {LISTEN: CONNECT, CONNECT: LISTEN}[opts.mode]
    host = self.__socket.get_hostname()
    port = self.__socket.get_port()
    endpoint = "%s:%s:%s" % (prog_mode, host, port)
    cmd = [opts.program, "--socket", endpoint] + opts.args

    if self.__verbosity >= 0:
      quoted = " ".join(map(shlex.quote, cmd))
      sys.stderr.write("Invoking: %s\n" % (quoted))

    proc = subprocess.Popen(cmd,
                            stdout=subprocess.PIPE,
                            stderr=subprocess.STDOUT)
    self.__proc = proc

    # Reads from the log pipe must never block the event loop.
    os.set_blocking(proc.stdout.fileno(), False)

    # Watch the log pipe for incoming data.
    self.__sel.register(proc.stdout, selectors.EVENT_READ)

  def wait_proc(self, timeout=0):
    """Check whether the child process has exited.

    timeout is passed to the process wait call; the default of 0
    makes this a poll.  When the process has exited, the exit is
    recorded, and the test is marked as failed if the exit code is
    non-zero or the exit happened before the test was complete.
    """
    if self.__proc_exited or self.__proc is None:
      return

    trace = self.__verbosity >= 4
    try:
      rc = self.__proc.wait(timeout)
    except subprocess.TimeoutExpired:
      if trace:
        sys.stderr.write("__proc.wait() timed out.\n")
      return

    if trace:
      sys.stderr.write("__proc.wait() returned %r.\n" % rc)

    self.__proc_exited = True

    # Only a zero exit code after completion counts as a clean exit.
    failed = not (rc == 0 and self.__complete)

    if failed or self.__verbosity >= 1:
      sys.stderr.write("%s: Program exited with code %d.\n" %
                       (sys.argv[0], rc))

    if failed:
      self.__test_failed = True

  def poll_connect(self):
    """Make one attempt to complete the socket connection.

    Returns True when the connection is established, in which case
    the socket is registered with the selector for reading and
    writing; returns False otherwise.
    """
    trace = self.__verbosity >= 4
    sock = self.__socket

    connected = sock.try_connect()
    if not connected:
      if trace:
        sys.stderr.write("__socket.try_connect() failed.\n")
      return False

    if trace:
      sys.stderr.write("__socket.try_connect() succeeded.\n")

    host = sock.get_hostname()
    port = sock.get_port()

    if self.__verbosity >= 1:
      if self.__args.mode == CONNECT:
        sys.stderr.write("Connected to %s port %d.\n" % (host, port))
      else:
        sys.stderr.write("Accepted connection.\n")

    # From now on the socket takes part in the event loop.
    self.__sel_enabled = True
    both = selectors.EVENT_READ | selectors.EVENT_WRITE
    self.__sel.register(sock, both)

    return True

  def connect_socket(self):
    """Keep trying to connect until success, process exit or timeout.

    Returns once the connection is made or the child process has
    exited.  If the deadline passes first, an error is reported and
    the test is marked as failed.  The delay between attempts grows
    with the time elapsed since the start, up to one second.
    """
    if self.__verbosity >= 1:
      sys.stderr.write("Waiting to complete the connection.\n")

    while True:
      now = self.get_time()
      self.wait_proc()

      # Stop if the program has gone away or the connection is up.
      if self.__proc_exited or self.poll_connect():
        return

      if now > self.__end_time:
        where = (sys.argv[0],
                 self.__socket.get_hostname(),
                 self.__socket.get_port())
        sys.stderr.write("%s: Failed to establish connection to "
                         "%s port %d.\n" % where)
        self.__test_failed = True
        return

      waited = now - self.__start_time
      if waited > 0:
        time.sleep(min(waited/4, 1))

  def handle_recv_closed(self):
    """React to the receive side of the socket closing.

    The test is marked as failed.  If sending is still in progress
    the socket is watched for writing only; otherwise the selector
    is disabled.
    """
    self.__test_failed = True
    self.__recv_eop = True

    if not self.__send_done:
      # Nothing more can be read, so only wait for write readiness.
      self.__sel.modify(self.__socket, selectors.EVENT_WRITE)
      return

    self.__sel_enabled = False

  def compare_packets(self, index, exp_p, recv_p):
    """Compare a received packet against the expected one.

    Only the first mismatch is reported: both packets are dumped to
    stderr and the test is marked as failed.  index is the packet
    index used in the message.
    """
    # Nothing to do for a match, or once a mismatch has been reported.
    if exp_p == recv_p or self.__recv_mismatch:
      return

    sys.stderr.write("%s: Mismatch at packet index %d.\n" %
                     (sys.argv[0], index))
    for label, packet in (("exp:  ", exp_p), ("recv: ", recv_p)):
      dump_packet(sys.stderr, label, packet)

    self.__test_failed = True
    self.__recv_mismatch = True

  def handle_recv_packet(self, p):
    """Process one received packet.

    The packet is added to the output file if there is one.  It is
    then compared with the next expected packet; if nothing more is
    expected and expected files were given, the excess is reported
    once and the test is marked as failed.
    """
    if self.__out_file is not None:
      self.__out_file.add_packet(p)

    expected = self.__exp_packet
    if expected is not None:
      self.compare_packets(self.__exp_index, expected, p)

      # Move on to the next expected packet.
      self.__exp_packet = next(self.__exp_iter, None)
      self.__exp_index += 1
      return

    # Without expected files any packet is acceptable; otherwise
    # complain about the excess only once.
    if len(self.__args.expect) == 0 or self.__recv_excess:
      return

    sys.stderr.write("%s: Too many packets received.\n" %
                     sys.argv[0])
    self.__recv_excess = True
    self.__test_failed = True

  def do_socket_recv(self):
    """Receive a batch of packets and process each one.

    A None entry in the batch marks EOF on the socket: it is
    reported, handled as a closed receive side, and the rest of the
    batch is ignored.
    """
    for p in self.__socket.receive():
      # Test for the EOF marker by identity so that the packet
      # type's own equality operator is never consulted.
      if p is None:
        sys.stderr.write("%s: EOF from socket.\n" % sys.argv[0])
        self.handle_recv_closed()
        return

      if self.__verbosity >= 3:
        sys.stderr.write("Received packet length %d.\n" % len(p))
        dump_packet(sys.stderr, "received: ", p)
      self.handle_recv_packet(p)

  def handle_send_done(self):
    """Record that every input packet has been sent.

    If the receive side is still open the socket is watched for
    reading only; otherwise the selector is disabled.
    """
    self.__send_done = True

    if not self.__recv_eop:
      # Nothing more to write, so only wait for read readiness.
      self.__sel.modify(self.__socket, selectors.EVENT_READ)
      return

    self.__sel_enabled = False

  def do_socket_send(self, p=None):
    """Send packets for as long as the socket accepts them.

    p is a new packet to start with; None means continue sending the
    current packet.  The method returns when the socket's send call
    reports that it could not finish, when the input packets run out
    (sending is then marked as done), or when the connection is
    lost.  A broken pipe or a connection reset is reported, disables
    the selector and marks the test as failed.
    """
    while True:
      # Dump only new packets, not the continuation of the current one.
      if p is not None and self.__verbosity >= 3:
        dump_packet(sys.stderr, "sending: ", p)

      try:
        if not self.__socket.send(p):
          return
      except (BrokenPipeError, ConnectionResetError) as exc:
        # Either error means the peer has gone away.
        if isinstance(exc, BrokenPipeError):
          reason = "a broken pipe"
        else:
          reason = "a connection reset"
        sys.stderr.write("%s: Send failed with %s.\n" %
                         (sys.argv[0], reason))
        self.__sel_enabled = False
        self.__test_failed = True
        return

      # The previous packet has gone out completely; fetch the next.
      p = next(self.__in_iter, None)
      if p is None:
        self.handle_send_done()
        return

  def start_packets(self):
    """Start sending: send the first input packet, if any.

    With no input packets, sending is marked as done immediately.
    When the selector is not enabled (the connection was never
    established) the socket is left alone, because it has not been
    registered with the selector.
    """
    p = next(self.__in_iter, None)

    if not self.__sel_enabled:
      # There is no registered socket to send on or to modify in the
      # selector; just record that there was nothing to send.
      if p is None:
        self.__send_done = True
      return

    if p is not None:
      self.do_socket_send(p)
    else:
      self.handle_send_done()

  def handle_socket_events(self, events):
    """Service the socket for the given selector event mask.

    Reading is done before writing.  A connection reset while
    receiving is reported and handled like a closed receive side.
    """
    if events & selectors.EVENT_READ:
      try:
        self.do_socket_recv()
      except ConnectionResetError:
        sys.stderr.write("%s: Connection reset by peer.\n" % sys.argv[0])
        self.handle_recv_closed()

    if events & selectors.EVENT_WRITE:
      self.do_socket_send()

  def add_log_data(self, data):
    """Add program output to the log and print the completed lines.

    data is appended to the pending partial line.  Each complete
    line is written to stderr with a "> " prefix when the verbosity
    is 0 or higher; whatever follows the last newline is kept for
    the next call.
    """
    *finished, pending = (self.__log_line + data).split(b"\n")

    if self.__verbosity >= 0:
      for raw in finished:
        text = raw.decode('utf-8', errors='replace')
        sys.stderr.write("> %s\n" % text)

    self.__log_line = pending

  def handle_log_event(self, events):
    """Read pending output from the program's log pipe.

    events is the selector event mask, which must be EVENT_READ.
    Data that was read is added to the log.  At EOF the pipe is
    removed from the selector so that it does not keep reporting as
    readable.
    """
    assert events == selectors.EVENT_READ
    stdout = self.__proc.stdout
    data = stdout.read(1024)

    # A non-blocking read returns None when no data is available.
    if data is None:
      return

    # An empty result means EOF: stop watching the pipe.
    if not data:
      self.__sel.unregister(stdout)
      return

    self.add_log_data(data)

  def poll_selector(self):
    """Wait for events until the deadline and dispatch them.

    Events on the socket go to the socket handler and events on the
    program's output pipe go to the log handler.
    """
    remaining = max(self.__end_time - self.__time_now, 0)

    for key, events in self.__sel.select(remaining):
      source = key.fileobj
      if source is self.__socket:
        self.handle_socket_events(events)
      if self.__proc is not None and source is self.__proc.stdout:
        self.handle_log_event(events)

  def check_running_done(self):
    """Return True when the running phase of the test should end.

    That is when sending is done and nothing more can or need be
    received, when the deadline has passed (which is reported and
    marks the test as failed), or when the program has exited.
    """
    awaiting = self.__exp_packet is not None

    if self.__send_done and (not awaiting or self.__recv_eop):
      return True

    if self.__time_now > self.__end_time:
      # Name the operation(s) that were still outstanding.
      if self.__send_done:
        oper = "Receive" if awaiting else "Test"
      else:
        oper = "Send and receive" if awaiting else "Send"

      elapsed = self.__time_now - self.__start_time
      sys.stderr.write("%s: %s timed out after %ssec.\n" %
                       (sys.argv[0], oper, elapsed))
      self.__timed_out = True
      self.__test_failed = True
      return True

    return True if self.__proc_exited else False

  def check_waiting_done(self):
    """Return True when the waiting phase of the test should end.

    That is when the receive side has closed, the deadline has
    passed, or the program has exited.
    """
    if self.__recv_eop or self.__time_now > self.__end_time:
      return True
    return True if self.__proc_exited else False

  def do_running(self):
    """Run the event loop for the main phase of the test.

    The loop polls the selector and the child process until the
    running phase is done or the selector is disabled.
    """
    if self.__verbosity >= 1:
      sys.stderr.write("Running the test.\n")

    finished = False
    while self.__sel_enabled and not finished:
      self.__time_now = self.get_time()
      self.poll_selector()
      self.wait_proc()
      finished = self.check_running_done()
    
  def do_waiting(self):
    """Run the event loop for the waiting phase after the test.

    The deadline is moved to the wait time from now, and the loop
    then polls the selector and the child process until the waiting
    phase is done or the selector is disabled.
    """
    began = self.get_time()
    self.__time_now = began
    self.__end_time = began + self.__wait_time

    if self.__verbosity >= 1:
      sys.stderr.write("Started waiting at time %s seconds.\n" %
                       began)

    finished = False
    while self.__sel_enabled and not finished:
      self.__time_now = self.get_time()
      self.poll_selector()
      self.wait_proc()
      finished = self.check_waiting_done()
    
  def end_program(self):
    """Stop the child process if it is still running.

    A failed test sends the terminate signal and a passed test sends
    SIGINT.  If the process has not exited within the wait time it
    is killed and then waited for without a time limit.
    """
    now = self.get_time()
    self.__time_now = now
    chatty = self.__verbosity >= 1

    if chatty:
      sys.stderr.write("Finished waiting at time %s seconds.\n" %
                       now)

    # Check whether the process has already exited.
    self.wait_proc()
    if self.__proc_exited:
      return

    # From here on the process may exit without that being an error.
    self.__complete = True

    proc = self.__proc
    if proc is None:
      return

    # Ask the process to stop.
    if self.__test_failed:
      if chatty:
        sys.stderr.write("Sending terminate signal.\n")
      proc.terminate()
    else:
      if chatty:
        sys.stderr.write("Sending interrupt signal.\n")
      proc.send_signal(signal.SIGINT)

    self.wait_proc(self.__wait_time)
    if not self.__proc_exited:
      # The process ignored the request: kill it and wait for the exit.
      sys.stderr.write("Sending kill signal.\n")
      proc.kill()
      self.wait_proc(None)
      assert(self.__proc_exited)

  def flush_output(self):
    """Close the output file, if one was opened."""
    out = self.__out_file
    if out is None:
      return
    out.close()

  def flush_log(self):
    """Read and log whatever output of the program is left.

    Does nothing if no program was started.  A final line without a
    terminating newline is logged as well rather than being dropped.
    """
    if self.__proc is None:
      return

    stdout = self.__proc.stdout

    # The read method is being called without select indicating that
    # the FD is ready, so re-enable blocking.
    os.set_blocking(stdout.fileno(), True)

    # Read remaining data until EOF.
    while True:
      data = stdout.read(1024)
      if not data:
        break
      self.add_log_data(data)

    # Terminate a trailing partial line so that it is printed too.
    if self.__log_line:
      self.add_log_data(b"\n")

  def report_results(self):
    """Print the test verdict and exit the interpreter.

    The exit status is 1 for a failed test and 0 for a passed one.
    """
    if self.__verbosity >= 1:
      sys.stderr.write("Finished test at time %s seconds.\n" %
                       self.get_time())

    if not self.__test_failed:
      print("Test passed!")
      sys.exit(0)

    print("Test failed.")
    sys.exit(1)

  def run(self):
    """Run the whole test from set-up to the final report.

    The last step reports the result and exits the interpreter.
    """
    # The steps of the test, in the order they are carried out.
    steps = (
      "create_iterators",
      "create_socket",
      "start_test",
      "invoke_program",
      "connect_socket",
      "start_packets",
      "do_running",
      "do_waiting",
      "end_program",
      "flush_output",
      "flush_log",
      "report_results",
    )
    for name in steps:
      getattr(self, name)()

# Run the test only when executed as a script, so that importing this
# module does not start a test.
if __name__ == "__main__":
  App(sys.argv[1:]).run()
